import matplotlib.pyplot as pyplot
import numpy
import svgutils as SVG
from matplotlib.colors import LinearSegmentedColormap,ListedColormap
from matplotlib import rc,rcParams, lines
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm,LogNorm
from matplotlib.ticker import MultipleLocator, ScalarFormatter,FuncFormatter,FormatStrFormatter
from scipy.optimize import fsolve as fsolve
import scipy.optimize as optimise
from matplotlib import gridspec
import scipy.interpolate as interp
from matplotlib.patches import ConnectionPatch,FancyBboxPatch,Rectangle
import os
from jqc import jqc_plot
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Wavelength in nm; it also names the data folder ACStark<Wavelength>nm.
Wavelength = 1064

# Factors multiplying the measured power of the beta = 0 (T0) and
# beta = 90 (T90) data sets in the intensity conversion further down.
# NOTE(review): presumably optical transmissions of the beam path -- confirm.
T0 = 0.7810
T90 = 0.99947

# Conversion from the measured voltage to optical power (W per V).
WattsPerVolt = 0.985

jqc_plot.plot_style("normal")
directory = os.path.dirname(os.path.abspath(__file__))

# Root folder holding the data for this wavelength. os.path.join uses the
# platform's separator instead of hard-coded backslashes (which are ordinary
# filename characters on Linux/macOS); the trailing "" keeps a final
# separator so appending a name to File_path_root still yields a path
# inside the folder.
File_path_root = os.path.join(directory, "ACStark" + str(Wavelength) + "nm", "")

def _channel(levels):
    '''
    Build one colour channel in LinearSegmentedColormap's segment format.

    Each 0-255 entry of *levels* is placed, in order, at the ramp positions
    0, 0.33, 0.66 and 1, with identical left/right values (no jumps).
    '''
    return [(position, level/255.0, level/255.0)
            for position, level in zip((0.0, 0.33, 0.66, 1.0), levels)]


# Both maps start from the JQC sand colour (244, 234, 168) at the bottom of
# the scale; the blue map passes through JQC blue (0, 70, 127) at 0.66.
colour_dict_twk_blue = {
    "red": _channel((244, 124, 0, 0)),
    "green": _channel((234, 154, 70, 32)),
    "blue": _channel((168, 148, 127, 58)),
}

colour_dict_twk_red = {
    "red": _channel((244, 229, 214, 170)),
    "green": _channel((234, 177, 120, 43)),
    "blue": _channel((168, 145, 122, 74)),
}

# Named JQC palette colours as RGB fractions.
JQC = {'red': (198.0/255.0, 62.0/255.0, 98.0/255.0),
       'blue': (0.0/255.0, 70.0/255.0, 127.0/255.0),
       'purple': (126.0/255.0, 29.0/255.0, 123.0/255.0),
       'sand': (244./255.0, 234./255.0, 168./255.0),
       'grayblue': (212./255.0, 213./255.0, 220./255.0),
       'green': (45.0/255.0, 159.0/255.0, 60.0/255.0)}

# Alpha channel shared by both maps: fully transparent at the bottom of the
# scale, rising linearly to opaque at the midpoint and opaque above it.
_alpha_ramp = ((0.0, 0.0, 0.0),
               (0.5, 1.0, 1.0),
               (1.0, 1.0, 1.0))

colour_dict_twk_blue_alpha = dict(colour_dict_twk_blue, alpha=_alpha_ramp)
colour_dict_twk_red_alpha = dict(colour_dict_twk_red, alpha=_alpha_ramp)


def _register_colormap(cmap):
    '''
    Register *cmap* under its own name so it can be passed as cmap='<name>'.

    pyplot.register_cmap is deprecated and no longer exists in recent
    matplotlib (>= 3.9); matplotlib.colormaps.register (matplotlib >= 3.5)
    is used when available, with the old call as a fallback.
    '''
    import matplotlib

    registry = getattr(matplotlib, "colormaps", None)
    if registry is None:
        pyplot.register_cmap(cmap=cmap)
    else:
        # force=True allows re-running the script in the same interpreter
        # (e.g. IPython) without an "already registered" ValueError.
        registry.register(cmap, force=True)


RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                            colour_dict_twk_blue_alpha)
_register_colormap(RbCs_map_twk_blue)

RbCs_map_twk_red = LinearSegmentedColormap("RbCs_map_tweak_red",
                                           colour_dict_twk_red_alpha)
_register_colormap(RbCs_map_twk_red)


def make_segments(x, y):
    '''
    Join consecutive (x, y) points into straight segments for LineCollection.

    Returns an array of shape (number of points - 1, 2, 2): entry k holds the
    start point and the end point of segment k, each as an (x, y) pair.
    '''
    # One (x, y) row per input point.
    coords = numpy.array([x, y]).T.reshape(-1, 2)
    starts = coords[:-1]
    ends = coords[1:]
    return numpy.stack((starts, ends), axis=1)

def colorline(x, y, z=None, cmap=pyplot.get_cmap('copper'),
        norm=None, linewidth=3, alpha=None, legend=False, ax=None):
    '''
    Draw a line through (x, y) whose segments are coloured by z, and add it
    to an Axes.

    Parameters
    ----------
    x, y : sequences of point coordinates of equal length.
    z : values mapped through norm and cmap to colour the segments. Defaults
        to len(x) values equally spaced on [0, 1]. A single number (or 0-d
        array) is wrapped into a one-element array.
    cmap : Colormap instance or registered colormap name.
    norm : Normalize instance. None (default) creates a fresh
        Normalize(0.0, 1.0) on every call, so calls never share one object.
    linewidth : line width in points.
    alpha : opacity applied to the whole line. None (default) leaves the
        alpha channel coming from cmap untouched; a number overrides it.
    legend : accepted for backward compatibility; not used.
    ax : Axes to draw on. None (default) uses the current axes.

    Returns
    -------
    The LineCollection that was added to the axes.
    '''
    # Default colours equally spaced on [0, 1].
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Wrap a bare number or a 0-d array in a one-element array.
    z = numpy.atleast_1d(numpy.asarray(z))

    if norm is None:
        norm = pyplot.Normalize(0.0, 1.0)
    if ax is None:
        ax = pyplot.gca()

    lc = LineCollection(make_segments(x, y), array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth, alpha=alpha)
    ax.add_collection(lc)

    return lc


############################ HELPERS ############################

# Beam waist w in the peak-intensity formula I = 2P / (pi w^2).
# NOTE(review): presumably in cm (i.e. 174 um), because the intensities are
# later multiplied by 1e-3 and labelled kW cm^-2 -- confirm.
Beam_waist = 174e-4


def volts_to_intensity(volts, transmission):
    '''
    Convert measured voltages to peak intensity 2P / (pi w^2).

    The power is P = volts * WattsPerVolt * transmission and the waist w is
    Beam_waist. Works element-wise on numpy arrays.
    '''
    return (2*volts*WattsPerVolt*transmission)/(numpy.pi*Beam_waist**2)


def load_theory(folder):
    '''
    Load one simulated spectrum from the files in *folder*.

    Returns a tuple (intensity, frequencies, strength_sigma, strength_pi):
    intensity is the first column of lines.dat and frequencies its remaining
    columns (one per transition); strength_sigma is sigma+.dat + sigma-.dat
    and strength_pi is pi.dat, each without its first column so that column
    i matches frequencies[:, i].
    '''
    lines = numpy.genfromtxt(os.path.join(folder, "lines.dat"))
    sigma_plus = numpy.genfromtxt(os.path.join(folder, "sigma+.dat"))[:, 1:]
    sigma_minus = numpy.genfromtxt(os.path.join(folder, "sigma-.dat"))[:, 1:]
    strength_pi = numpy.genfromtxt(os.path.join(folder, "pi.dat"))[:, 1:]
    return lines[:, 0], lines[:, 1:], sigma_plus + sigma_minus, strength_pi


def plot_theory(ax, intensity, frequencies, strength_sigma, strength_pi,
                vmin, vmax):
    '''
    Draw every simulated transition of one spectrum on *ax*.

    All transitions are first drawn as plain sand-coloured lines at
    zorder=0. They are then overlaid by red lines shaded by strength_sigma,
    and finally by blue lines shaded by strength_pi. Both overlays use a
    LogNorm between vmin and vmax. The three passes are kept separate so
    every blue line is added after every red one.

    Returns (last red LineCollection, last blue LineCollection) for use in
    colour bars.
    '''
    # colorline() draws on the current axes, so make *ax* current.
    pyplot.sca(ax)
    # Both quantities are scaled by 1e-3 for plotting; the frequencies are
    # recorded in kHz and plotted in MHz.
    x = intensity*1e-3
    y = frequencies*1e-3
    n_lines = frequencies.shape[1]

    for column in range(n_lines):
        ax.plot(x, y[:, column], color=JQC['sand'], zorder=0)

    cl_sigma = None
    for column in range(n_lines):
        cl_sigma = colorline(x, y[:, column], strength_sigma[:, column],
                             cmap='RbCs_map_tweak_red',
                             norm=LogNorm(vmin, vmax=vmax), linewidth=2.0)

    cl_pi = None
    for column in range(n_lines):
        cl_pi = colorline(x, y[:, column], strength_pi[:, column],
                          cmap='RbCs_map_tweak_blue',
                          norm=LogNorm(vmin, vmax=vmax), linewidth=2.0)

    return cl_sigma, cl_pi


def add_inset_image(ax, image_path):
    '''
    Show the image at *image_path*, without axes, in the lower-left part of
    *ax* (x from 0.1 to 0.65 and y from 0 to 0.55 in axes fraction).

    Returns the inset Axes.
    '''
    inset = inset_axes(ax, width="100%", height="100%",
                       bbox_to_anchor=(.1, 0., .55, .55),
                       bbox_transform=ax.transAxes)
    inset.imshow(pyplot.imread(image_path))
    inset.axis('off')
    return inset


############################ DATA ###############################

# Experimental data: column 0 is converted from volts to intensity below;
# columns 1 and 2 are plotted as frequency and its error bar.
Data_Exp0P = numpy.genfromtxt(os.path.join(
    File_path_root, "Experimental Data", "Beta= 0", "MF5.csv"), delimiter=',')
Data_Exp0S = numpy.genfromtxt(os.path.join(
    File_path_root, "Experimental Data", "Beta= 0", "MF6.csv"), delimiter=',')
Data_Exp90 = numpy.genfromtxt(os.path.join(
    File_path_root, "Experimental Data", "Beta= 90", "MF5.csv"), delimiter=',')

Data_Exp0P[:, 0] = volts_to_intensity(Data_Exp0P[:, 0], T0)
Data_Exp0S[:, 0] = volts_to_intensity(Data_Exp0S[:, 0], T0)
Data_Exp90[:, 0] = volts_to_intensity(Data_Exp90[:, 0], T90)

# The beta = 0 panel shows the MF5 and MF6 points together.
Data_Exp0 = numpy.vstack((Data_Exp0P, Data_Exp0S))

# Simulated spectra: (intensity, frequencies, sigma strength, pi strength).
theory_beta0 = load_theory(os.path.join(File_path_root, "Beta = 0"))
theory_beta90 = load_theory(os.path.join(File_path_root, "Beta = 90"))


############################ PLOTTING ###########################

fig = pyplot.figure()
# Columns: beta = 90 panel, gap, beta = 0 panel, gap, pi colour bar, gap,
# sigma colour bar.
grid = gridspec.GridSpec(1, 7, width_ratios=[1.0, 0.03, 1, 0.03, 0.07, 0.02, 0.07])
# Transition strengths are coloured on a log scale from vmin to vmax.
vmin = 1e-2
vmax = 1
# y tick labels are shown relative to this frequency (MHz).
Offset = 980
cbar_ax1 = fig.add_subplot(grid[6])
cbar_ax2 = fig.add_subplot(grid[4])

########################### BETA = 0  ########################

Beta0 = fig.add_subplot(grid[2])
Beta0.errorbar(Data_Exp0[:, 0]*1e-3, Data_Exp0[:, 1], yerr=Data_Exp0[:, 2],
               fmt='o', color='k')
cl_sigma, cl_pi = plot_theory(Beta0, *theory_beta0, vmin=vmin, vmax=vmax)

# Each colour bar is drawn into its own gridspec column (cax).
cbar1 = pyplot.colorbar(cl_sigma, cax=cbar_ax1)
cbar1.set_label('Relative Transition Strength')
cbar1.set_ticks([1e-3, 1e-2, 1e-1, 1e-0])
cbar1.ax.set_title("$y$", color=JQC['red'])

cbar2 = pyplot.colorbar(cl_pi, cax=cbar_ax2)
cbar2.set_ticks([])
cbar2.ax.set_title("$z$", color=JQC['blue'])

Beta0.set_xlabel(r"Intensity (kW$\,$cm$^{-2}$)")
Beta0.set_xlim(0, 10)
Beta0.set_ylim(979.8, 980.6)
Beta0.ticklabel_format(axis='y', useOffset=Offset, style='plain')
Beta0.yaxis.offsetText.set_visible(False)
insetax1 = add_inset_image(Beta0, os.path.join(directory, "Insets", "Beta0_axis.png"))
Beta0.text(0.05, 0.05, "(b)", fontsize=20, transform=Beta0.transAxes)
Beta0.tick_params(axis='y', which='both', length=0)

########### BETA = 90 ################

Beta90 = fig.add_subplot(grid[0], sharey=Beta0)
Beta90.errorbar(Data_Exp90[:, 0]*1e-3, Data_Exp90[:, 1], yerr=Data_Exp90[:, 2],
                fmt='o', color='k')
plot_theory(Beta90, *theory_beta90, vmin=vmin, vmax=vmax)

Beta90.set_xlim(0, 11)
Beta90.set_ylim(979.8, 980.6)
Beta90.set_xlabel(r"Intensity (kW$\,$cm$^{-2}$)")
Beta90.set_ylabel("Transition Frequency (MHz)")
Beta90.ticklabel_format(axis='y', useOffset=Offset, style='plain')
Beta90.yaxis.offsetText.set_visible(False)
insetax2 = add_inset_image(Beta90, os.path.join(directory, "Insets", "Beta90_axis.png"))
# The hidden offset text is replaced by an explicit label above the axes.
Beta90.annotate("+%.0f MHz" % Offset, xy=(0.01, 1.01), xycoords="axes fraction")
Beta90.text(0.05, 0.05, "(a)", fontsize=20, transform=Beta90.transAxes)

cbar_ax2.tick_params(axis='y', which='both', length=0)

pyplot.tight_layout()
pyplot.subplots_adjust(wspace=0.)
# The shared y axis keeps its tick labels on the beta = 90 panel only.
pyplot.setp(Beta0.get_yticklabels(), visible=False)
pyplot.savefig(os.path.join(File_path_root, "{}.pdf".format(Wavelength)), dpi=1500)
pyplot.savefig(os.path.join(File_path_root, "{}.png".format(Wavelength)))


pyplot.show()
